#!/usr/bin/env python3
"""
Phase 5: Bilateral Projection Exercise

Uses Model 2c coefficients + UN population projections through 2050
to project how bilateral demographic distance (and thus portfolio flow
allocations) will shift as countries age at different rates.

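Outputs (written to OUTPUT_DIR):
  - bilateral_projections.csv: demographic effect for every key-country pair
    and projection year, with changes relative to the 2024 baseline
  - projection_summary_by_country.csv: net projected reallocation by country
  - top_bilateral_shifts_2050.csv: the 10 largest increases and 10 largest
    decreases in the demographic effect by 2050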
"""

import pandas as pd
import numpy as np
from pathlib import Path

# Directory layout: results are written under this gravity project, while the
# demographic panel (observed + projected) is read from the 140-country project.
BASE_DIR = Path("/mnt/c/demographics_capital_flows", "gravity_bilateral")
FOLLOWUP_DIR = Path("/mnt/c/demographics_capital_flows", "multilateral", "140_country")
OUTPUT_DIR = BASE_DIR.joinpath("output", "tables")

# Model 2c point estimates, copied from gravity_results.csv (N=92,117).
COEFS = dict(
    # Demographic-distance main effects (dZ_k = Z_k of reporter - Z_k of partner)
    dZ_1=1.294765,
    dZ_2=-0.161998,
    dZ_3=0.006172,
    # The same distances interacted with partner capital-account openness
    dZ_1_x_kaopen_j=0.966978,
    dZ_2_x_kaopen_j=-0.115754,
    dZ_3_x_kaopen_j=0.003939,
    # Partner capital-account openness level
    kaopen_j=0.191148,
)

# Years at which bilateral effects are evaluated; 2024 serves as the baseline.
PROJ_YEARS = [2024, 2030, 2040, 2050]


def load_demographics():
    """Read the follow-up panel and build per-country KAOPEN and GDP lookups.

    Returns:
        A 3-tuple ``(panel, kaopen_map, gdp_map)``:
        - panel: ``full_panel.csv`` restricted (in this order) to whichever of
          iso3, year, Z_1, Z_2, Z_3, kaopen, ngdp_usd, country_name exist.
        - kaopen_map: iso3 -> last non-missing ``kaopen`` at or before 2024,
          or ``{}`` when the column is absent.
        - gdp_map: iso3 -> last non-missing ``ngdp_usd`` at or before 2024,
          or ``{}`` when the column is absent.
    """
    panel = pd.read_csv(FOLLOWUP_DIR / "data" / "processed" / "full_panel.csv")
    wanted = ('iso3', 'year', 'Z_1', 'Z_2', 'Z_3', 'kaopen', 'ngdp_usd', 'country_name')
    panel = panel[[col for col in wanted if col in panel.columns]].copy()

    # Projected years carry Z variables only, so KAOPEN and GDP are frozen at
    # their latest observation up to 2024. GroupBy.last() skips NaN per column,
    # so a missing recent value falls back to an earlier year.
    observed = panel.loc[panel['year'] <= 2024].sort_values('year')
    latest = observed.groupby('iso3').last()

    def _lookup(col):
        # iso3 -> latest value, or an empty mapping when the panel lacks `col`.
        return latest[col].to_dict() if col in latest.columns else {}

    return panel, _lookup('kaopen'), _lookup('ngdp_usd')


def compute_bilateral_demographic_effect(z_i, z_j, kaopen_j, coefs=None):
    """Compute the demographic component of log(Flow_ij) from Model 2c.

    With dZ_k = Z_k(i) - Z_k(j) for k = 1, 2, 3 the predictor is

        sum_k b_k * dZ_k + kaopen_j * (sum_k g_k * dZ_k + c)

    where b_k = coefs['dZ_k'], g_k = coefs['dZ_k_x_kaopen_j'] and
    c = coefs['kaopen_j'].

    Args:
        z_i: Reporter values, indexable by 'Z_1', 'Z_2', 'Z_3' (e.g. a pandas
            row or a dict).
        z_j: Partner values, indexed the same way.
        kaopen_j: Partner capital-account openness. When missing (NaN or
            None) every KAOPEN term is dropped, which is numerically the same
            as evaluating at kaopen_j = 0 rather than propagating NaN.
            NOTE(review): that silently imputes 0 for partners without KAOPEN.
        coefs: Coefficient mapping with the same keys as ``COEFS``; defaults
            to the module-level ``COEFS`` (Model 2c).

    Returns:
        The predicted contribution to log(Flow_ij). It is NaN if any Z input
        is NaN. The ``kaopen_j`` level term is not demographic; it cancels in
        year-over-year differences whenever kaopen_j is held fixed across
        years.
    """
    if coefs is None:
        coefs = COEFS

    dZ_1 = z_i['Z_1'] - z_j['Z_1']
    dZ_2 = z_i['Z_2'] - z_j['Z_2']
    dZ_3 = z_i['Z_3'] - z_j['Z_3']

    effect = (coefs['dZ_1'] * dZ_1 +
              coefs['dZ_2'] * dZ_2 +
              coefs['dZ_3'] * dZ_3)

    # pd.isna also accepts None, where np.isnan would raise TypeError.
    if not pd.isna(kaopen_j):
        effect += (coefs['dZ_1_x_kaopen_j'] * dZ_1 * kaopen_j +
                   coefs['dZ_2_x_kaopen_j'] * dZ_2 * kaopen_j +
                   coefs['dZ_3_x_kaopen_j'] * dZ_3 * kaopen_j +
                   coefs['kaopen_j'] * kaopen_j)

    return effect


def _print_banner(title):
    """Print a section title framed by 70-character '=' rules (blank line first)."""
    print(f"\n{'=' * 70}")
    print(title)
    print(f"{'=' * 70}")


def _project_pair_effects(fp, key_countries, kaopen_map, name_map):
    """Build one row per (year, reporter, partner) holding the Model 2c demographic effect.

    Only countries present in a year's panel slice take part in that year.
    Partner KAOPEN is looked up in ``kaopen_map`` and is therefore the same in
    every projection year; only the Z variables change over time.
    """
    records = []
    for year in PROJ_YEARS:
        year_data = fp[fp['year'] == year].set_index('iso3')
        present = [iso for iso in key_countries if iso in year_data.index]

        for i_iso in present:
            z_i = year_data.loc[i_iso]
            for j_iso in present:
                if j_iso == i_iso:
                    continue
                z_j = year_data.loc[j_iso]
                kaopen_j = kaopen_map.get(j_iso, np.nan)
                records.append({
                    'year': year,
                    'reporter': i_iso,
                    'partner': j_iso,
                    'reporter_name': name_map.get(i_iso, i_iso),
                    'partner_name': name_map.get(j_iso, j_iso),
                    'dZ_1': z_i['Z_1'] - z_j['Z_1'],
                    'dZ_2': z_i['Z_2'] - z_j['Z_2'],
                    'dZ_3': z_i['Z_3'] - z_j['Z_3'],
                    'kaopen_j': kaopen_j,
                    'demo_effect': compute_bilateral_demographic_effect(z_i, z_j, kaopen_j),
                })
    return pd.DataFrame(records)


def _print_shift_table(heading, rows):
    """Print reporter→partner rows with Δlog, %Δ and the 2050 ΔZ₁ under `heading`."""
    print(heading)
    print(f"  {'Reporter':<6} {'→':>2} {'Partner':<6}  {'Δlog':>8}  {'%Δ':>8}  {'ΔZ₁_2050':>9}")
    for _, row in rows.iterrows():
        print(f"  {row['reporter']:<6} → {row['partner']:<6}  {row['delta_effect']:>+8.3f}  "
              f"{row['pct_change']:>+7.1f}%  {row['dZ_1']:>+9.3f}")


def _print_time_paths(proj_df, pairs):
    """Print each pair's demographic effect in every PROJ_YEARS year plus its 2024→2050 %Δ."""
    print(f"  {'Pair':<12}  {'2024':>8}  {'2030':>8}  {'2040':>8}  {'2050':>8}  {'%Δ 24-50':>10}")
    for i_iso, j_iso in pairs:
        pair_data = proj_df[(proj_df['reporter'] == i_iso) & (proj_df['partner'] == j_iso)]
        if len(pair_data) < 2:
            continue
        vals = {int(y): eff for y, eff in zip(pair_data['year'], pair_data['demo_effect'])}
        if 2024 in vals and 2050 in vals:
            pct = (np.exp(vals[2050] - vals[2024]) - 1) * 100
        else:
            pct = np.nan
        print(f"  {i_iso}→{j_iso:<6}  " + "  ".join(f"{vals.get(y, np.nan):>+8.3f}" for y in PROJ_YEARS)
              + f"  {pct:>+9.1f}%")


def _summarize_by_country(proj_2050):
    """Average each country's 2050-vs-2024 delta as reporter (outward) and as partner (received).

    ``net_pressure`` is outward minus received. Rows are sorted by it in
    descending order, and the index (the union of reporters and partners) is
    named 'iso3'.
    """
    outward = proj_2050.groupby('reporter')['delta_effect'].mean()
    received = proj_2050.groupby('partner')['delta_effect'].mean()
    summary = pd.DataFrame({'avg_outward_delta': outward, 'avg_received_delta': received})
    summary['net_pressure'] = summary['avg_outward_delta'] - summary['avg_received_delta']
    summary = summary.sort_values('net_pressure', ascending=False)
    summary.index.name = 'iso3'
    return summary


def main():
    """Run the bilateral projection exercise, print diagnostics and write the CSV outputs.

    Writes bilateral_projections.csv, projection_summary_by_country.csv and
    top_bilateral_shifts_2050.csv under OUTPUT_DIR, creating it if needed.

    Returns:
        (proj_df, country_summary): the pair x year projection table, with
        2024-baseline deltas, and the per-country 2050-vs-2024 summary.

    Raises:
        RuntimeError: if no pair of key countries has data in any PROJ_YEARS year.
    """
    print("=" * 70)
    print("PHASE 5: BILATERAL PROJECTION EXERCISE")
    print("=" * 70)

    # The GDP lookup is returned by load_demographics() but not used here.
    fp, kaopen_map, _gdp_map = load_demographics()

    name_map = {}
    if 'country_name' in fp.columns:
        name_map = fp.dropna(subset=['country_name']).groupby('iso3')['country_name'].first().to_dict()

    # Select key countries for bilateral analysis:
    # major economies + demographically interesting cases.
    key_countries = [
        # Aging AEs
        'JPN', 'DEU', 'ITA', 'KOR', 'ESP', 'GBR', 'FRA', 'USA', 'CAN', 'AUS',
        # Large EMs
        'CHN', 'IND', 'BRA', 'MEX', 'IDN', 'TUR', 'THA', 'VNM',
        # Young/growing
        'NGA', 'EGY', 'PHL', 'BGD', 'PAK', 'KEN', 'ETH',
        # Other significant
        'SAU', 'ARE', 'SGP', 'ZAF', 'RUS', 'POL', 'MYS',
    ]

    # Keep countries with a non-missing Z_1 in at least one projection year.
    has_z = fp['year'].isin(PROJ_YEARS) & fp['Z_1'].notna()
    available_countries = set(fp.loc[has_z, 'iso3'].unique())
    key_countries = [c for c in key_countries if c in available_countries]
    print(f"\nKey countries with projection data: {len(key_countries)}")

    proj_df = _project_pair_effects(fp, key_countries, kaopen_map, name_map)
    print(f"Computed {len(proj_df):,} bilateral projections")
    if proj_df.empty:
        raise RuntimeError(
            f"No bilateral projections computed: fewer than two key countries have data in any of {PROJ_YEARS}")

    # Changes are measured against the same ordered pair's 2024 effect.
    baseline = (proj_df.loc[proj_df['year'] == 2024, ['reporter', 'partner', 'demo_effect']]
                .rename(columns={'demo_effect': 'effect_2024'}))
    proj_df = proj_df.merge(baseline, on=['reporter', 'partner'], how='left')
    proj_df['delta_effect'] = proj_df['demo_effect'] - proj_df['effect_2024']
    # Log-point change -> implied percentage change in the flow level.
    proj_df['pct_change'] = (np.exp(proj_df['delta_effect']) - 1) * 100

    # 2050 rows, carrying each pair's actual 2024 ΔZ₁ for the paper table.
    dz1_2024 = (proj_df.loc[proj_df['year'] == 2024, ['reporter', 'partner', 'dZ_1']]
                .rename(columns={'dZ_1': 'dZ_1_2024'}))
    proj_2050 = proj_df[proj_df['year'] == 2050].merge(dz1_2024, on=['reporter', 'partner'], how='left')

    # ===================================================================
    # Top bilateral shifts by 2050
    # ===================================================================
    _print_banner("TOP 20 BILATERAL SHIFTS BY 2050 (vs 2024)")
    _print_shift_table("\nLargest INCREASES in demographic flow pressure (old→young intensifying):",
                       proj_2050.nlargest(20, 'delta_effect'))
    _print_shift_table("\nLargest DECREASES in demographic flow pressure (convergence/reversal):",
                       proj_2050.nsmallest(20, 'delta_effect'))

    # ===================================================================
    # Time path for selected pairs
    # ===================================================================
    interesting_pairs = [
        ('JPN', 'IND'), ('JPN', 'NGA'), ('JPN', 'IDN'),
        ('DEU', 'IND'), ('DEU', 'NGA'), ('DEU', 'TUR'),
        ('KOR', 'VNM'), ('KOR', 'PHL'), ('KOR', 'IND'),
        ('USA', 'IND'), ('USA', 'NGA'), ('USA', 'MEX'),
        ('CHN', 'IND'), ('CHN', 'NGA'), ('GBR', 'IND'),
    ]
    _print_banner("TIME PATHS FOR SELECTED PAIRS")
    _print_time_paths(proj_df, interesting_pairs)

    # ===================================================================
    # Country-level net reallocation pressure (mean delta as reporter vs as partner)
    # ===================================================================
    _print_banner("NET REALLOCATION PRESSURE BY COUNTRY (2050 vs 2024)")
    country_summary = _summarize_by_country(proj_2050)

    print(f"\n  {'Country':<6}  {'Outward Δ':>10}  {'Received Δ':>10}  {'Net':>10}  {'Interpretation'}")
    print(f"  {'-' * 65}")
    for iso, row in country_summary.iterrows():
        if row['net_pressure'] > 0.05:
            interp = "aging → more outward"
        elif row['net_pressure'] < -0.05:
            interp = "young → more received"
        else:
            interp = "balanced"
        print(f"  {iso:<6}  {row['avg_outward_delta']:>+10.3f}  {row['avg_received_delta']:>+10.3f}  "
              f"{row['net_pressure']:>+10.3f}  {interp}")

    # ===================================================================
    # Save results
    # ===================================================================
    OUTPUT_DIR.mkdir(parents=True, exist_ok=True)

    proj_df.to_csv(OUTPUT_DIR / "bilateral_projections.csv", index=False)
    print(f"\nSaved: {OUTPUT_DIR / 'bilateral_projections.csv'}")

    country_summary.to_csv(OUTPUT_DIR / "projection_summary_by_country.csv")
    print(f"Saved: {OUTPUT_DIR / 'projection_summary_by_country.csv'}")

    # Top-10 increases followed by top-10 decreases, for the paper.
    top_shifts = pd.concat([proj_2050.nlargest(10, 'delta_effect'),
                            proj_2050.nsmallest(10, 'delta_effect')])
    paper_table = (top_shifts[['reporter', 'partner', 'dZ_1_2024', 'delta_effect', 'pct_change']]
                   .rename(columns={'dZ_1_2024': 'dZ1_2024', 'delta_effect': 'delta_log_flow'}))
    paper_table.to_csv(OUTPUT_DIR / "top_bilateral_shifts_2050.csv", index=False)
    print(f"Saved: {OUTPUT_DIR / 'top_bilateral_shifts_2050.csv'}")

    return proj_df, country_summary


if __name__ == "__main__":
    # Run the exercise when executed as a script. The returned tables are bound
    # to module-level names (presumably for interactive inspection, e.g. via
    # `python -i`); nothing else in this file reads them.
    proj_df, country_summary = main()
